Flex Checkpoint Work Notes
1. Key Components of Flex Checkpoint
1.1 reshard_sharded_state_dict
def reshard_sharded_state_dict(
src_sharded_state_dict: ShardedStateDict,
dst_sharded_state_dict: ShardedStateDict,
process_group: Group,
coordinator_rank: int | None = 0,
offload: bool | None = False,
    aoa_config: dict[str, list[str]] | None = None,
) -> None:
local_src_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in src_sharded_state_dict.items()
}
global_src_state_dict_shard_info = []
dist.all_gather_object(
global_src_state_dict_shard_info,
local_src_state_dict_shard_info,
group=process_group,
)
src_state_dict_shard_info = {}
for rank_shard_info in global_src_state_dict_shard_info:
for key, tensor_shard_info in rank_shard_info.items():
if key not in src_state_dict_shard_info:
src_state_dict_shard_info[key] = []
src_state_dict_shard_info[key].append(tensor_shard_info)
# check validity
check_src_state_dict_validity(src_state_dict_shard_info)
local_dst_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in dst_sharded_state_dict.items()
}
global_dst_state_dict_shard_info = []
dist.all_gather_object(
global_dst_state_dict_shard_info,
local_dst_state_dict_shard_info,
group=process_group,
)
dst_state_dict_shard_info = {}
for rank_shard_info in global_dst_state_dict_shard_info:
for key, tensor_shard_info in rank_shard_info.items():
if key not in dst_state_dict_shard_info:
dst_state_dict_shard_info[key] = []
dst_state_dict_shard_info[key].append(tensor_shard_info)
# check validity
check_dst_state_dict_validity(dst_state_dict_shard_info)
check_src_dst_state_dict_validity(
src_state_dict_shard_info, dst_state_dict_shard_info
)
# build metadata
state_dict_metadata = {
tensor_name: [
LocalTensorMetadata(
global_offset=shard_info[0],
local_shape=shard_info[1],
dtype=shard_info[2],
)
for shard_info in shard_infos
]
for tensor_name, shard_infos in src_state_dict_shard_info.items()
}
virtual_file_path = f"vfile_{dist.get_rank()}"
local_storage_metadata = {
LocalTensorIndex(
tensor_key=value.key,
global_offset=value.global_offset,
): virtual_file_path
for key, value in src_sharded_state_dict.items()
}
global_storage_metadata: list[dict[LocalTensorIndex, str]] = []
dist.all_gather_object(
global_storage_metadata,
local_storage_metadata,
group=process_group,
)
# Merge storage metadata
storage_metadata: dict[LocalTensorIndex, str] = {}
for rank_storage_metadata in global_storage_metadata:
storage_metadata.update(rank_storage_metadata)
# Prepare metadata for loading
metadata = Metadata(
state_dict_metadata=state_dict_metadata,
storage_metadata=storage_metadata,
flat_mapping=None,
)
# Extract local tensors
src_state_dict = {
key: value.local_tensor for key, value in src_sharded_state_dict.items()
}
dst_state_dict = dst_sharded_state_dict
# reshard using _load_state_dict
_load_state_dict(
target_state_dict=dst_state_dict,
source_state_dict={virtual_file_path: src_state_dict},
metadata_list=[metadata],
coordinator_rank=coordinator_rank,
process_group=process_group,
offload=offload,
)
This function mainly builds the metadata needed during resharding; the actual reshard work happens inside _load_state_dict. Both state_dict_metadata and storage_metadata end up containing the shard information of every rank, i.e. the complete global view.
virtual_file_path is used because the actual data is already available at this point, namely each rank's local_tensor values, so nothing needs to be read from files; the virtual path only keeps the format aligned with the file-based loading path.
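A minimal usage sketch (not from the source; it assumes both models expose a plain state_dict(), that build_sharded_state_dict from section 1.3 is available, and that the caller already holds the process group spanning all participating ranks):
def reshard_between_layouts(src_model, dst_model, src_rules, dst_rules, group):
    # Source layout: sharded view of the parameters under the old parallel strategy.
    src_sd = build_sharded_state_dict(src_model.state_dict(), src_rules)
    # Destination layout: freshly initialized parameters under the new strategy;
    # their local_tensor buffers are filled in place by the reshard.
    dst_sd = build_sharded_state_dict(dst_model.state_dict(), dst_rules)
    reshard_sharded_state_dict(
        src_sharded_state_dict=src_sd,
        dst_sharded_state_dict=dst_sd,
        process_group=group,
        coordinator_rank=0,
        offload=False,
    )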
1.1.1 How the global information is built
Building state_dict_metadata, which stores each tensor's global metadata:
# Step 1: each rank collects its own shard info
local_src_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in src_sharded_state_dict.items()
}
# Step 2: all-gather the info from every rank
global_src_state_dict_shard_info = []
dist.all_gather_object(
global_src_state_dict_shard_info,
local_src_state_dict_shard_info,
group=process_group,
)
# Result: every rank now holds every rank's info
global_src_state_dict_shard_info = [
    # info from rank 0
    {"linear.weight": ((0, 0), (256, 512), "float32", (1024, 512), False)},
    # info from rank 1
    {"linear.weight": ((256, 0), (256, 512), "float32", (1024, 512), False)},
    # info from rank 2
    {"linear.weight": ((512, 0), (256, 512), "float32", (1024, 512), False)},
    # info from rank 3
    {"linear.weight": ((768, 0), (256, 512), "float32", (1024, 512), False)},
]
# Step 3: regroup into global info keyed by tensor name
src_state_dict_shard_info = {
"linear.weight": [
((0, 0), (256, 512), "float32", (1024, 512), False), # rank 0
((256, 0), (256, 512), "float32", (1024, 512), False), # rank 1
((512, 0), (256, 512), "float32", (1024, 512), False), # rank 2
((768, 0), (256, 512), "float32", (1024, 512), False), # rank 3
]
}
# Step 4: build the global state_dict_metadata
state_dict_metadata = {
"linear.weight": [
LocalTensorMetadata(global_offset=(0, 0), local_shape=(256, 512), dtype="float32"), # rank 0
LocalTensorMetadata(global_offset=(256, 0), local_shape=(256, 512), dtype="float32"), # rank 1
LocalTensorMetadata(global_offset=(512, 0), local_shape=(256, 512), dtype="float32"), # rank 2
LocalTensorMetadata(global_offset=(768, 0), local_shape=(256, 512), dtype="float32"), # rank 3
]
}
Building storage_metadata, which records where each tensor's actual data is stored:
# Step 1: each rank builds its own storage mapping
virtual_file_path = f"vfile_{dist.get_rank()}"
local_storage_metadata = {
LocalTensorIndex(
tensor_key=value.key,
global_offset=value.global_offset,
): virtual_file_path
for key, value in src_sharded_state_dict.items()
}
# rank 0's local mapping
local_storage_metadata = {
LocalTensorIndex("linear.weight", (0, 0)): "vfile_0",
}
# Step 2: all-gather the storage mappings from every rank
global_storage_metadata: list[dict[LocalTensorIndex, str]] = []
dist.all_gather_object(
global_storage_metadata,
local_storage_metadata,
group=process_group,
)
# Result: every rank now holds every rank's storage mapping
global_storage_metadata = [
    # rank 0's mapping
    {LocalTensorIndex("linear.weight", (0, 0)): "vfile_0"},
    # rank 1's mapping
    {LocalTensorIndex("linear.weight", (256, 0)): "vfile_1"},
    # rank 2's mapping
    {LocalTensorIndex("linear.weight", (512, 0)): "vfile_2"},
    # rank 3's mapping
    {LocalTensorIndex("linear.weight", (768, 0)): "vfile_3"},
]
# Step 3: merge into the global storage_metadata
storage_metadata: dict[LocalTensorIndex, str] = {}
for rank_storage_metadata in global_storage_metadata:
storage_metadata.update(rank_storage_metadata)
# the final global storage_metadata
storage_metadata = {
LocalTensorIndex("linear.weight", (0, 0)): "vfile_0", # rank 0
LocalTensorIndex("linear.weight", (256, 0)): "vfile_1", # rank 1
LocalTensorIndex("linear.weight", (512, 0)): "vfile_2", # rank 2
LocalTensorIndex("linear.weight", (768, 0)): "vfile_3", # rank 3
}
1.1.2 Why is the global information needed?
Resharding requires the complete shard layout
# Resharding process:
# source: 4 shards -> target: 2 shards
# all source shards must be known in order to reshard correctly
source_shards = [
((0, 0), (256, 512)), # rank 0
((256, 0), (256, 512)), # rank 1
((512, 0), (256, 512)), # rank 2
((768, 0), (256, 512)), # rank 3
]
# each target shard combines data from multiple source shards
target_shard_0 = combine(source_shards[0], source_shards[1])  # needs data from rank 0 and rank 1
target_shard_1 = combine(source_shards[2], source_shards[3])  # needs data from rank 2 and rank 3
**Data access requires a global mapping**
# _load_state_dict needs to know:
# 1. where each shard lives (storage_metadata)
# 2. each shard's shape and position (state_dict_metadata)
def load_shard(tensor_name, global_offset):
    # locate the shard using the global information
index = LocalTensorIndex(tensor_name, global_offset)
file_path = storage_metadata[index] # "vfile_0"
    # fetch the data from the corresponding source
if file_path in source_state_dict:
return source_state_dict[file_path][tensor_name]
**Validation requires a global view**
# verifying shard completeness requires the global view
def validate_completeness():
    # check that every expected shard is present
expected_shards = [
(0, 0), (256, 0), (512, 0), (768, 0)
]
for offset in expected_shards:
index = LocalTensorIndex("linear.weight", offset)
if index not in storage_metadata:
raise ValueError(f"Missing shard at {offset}")
1.2 Summary of the utils helper components
1. Index conversion utilities
ravel_index(indices, shape)
def ravel_index(indices, shape):
idx = 0
for i, dim in zip(indices, shape):
idx = idx * dim + i
return idx
Purpose: convert a multi-dimensional index into a linear index (row-major order).
Walkthrough:
# Example: shape = (2, 3, 4)
# convert the multi-dimensional index (1, 2, 3) into a linear index
# computation:
# step 1: idx = 0 * 2 + 1 = 1
# step 2: idx = 1 * 3 + 2 = 5
# step 3: idx = 5 * 4 + 3 = 23
# result: linear index = 23
# check: in a 2x3x4 tensor, position (1, 2, 3) indeed has linear index 23
Typical uses:
- converting a position in a multi-dimensional tensor into its linear (flat) memory offset
- locating an element's position in the global tensor during shard computations
unravel_index(idx, shape)
def unravel_index(idx, shape):
indices = []
for dim in reversed(shape):
indices.append(idx % dim)
idx //= dim
return tuple(reversed(indices))
Purpose: convert a linear index back into a multi-dimensional index.
Walkthrough:
# Example: shape = (2, 3, 4), idx = 23
# convert linear index 23 back into a multi-dimensional index
# computation (right to left):
# dim=4: indices.append(23 % 4 = 3), idx = 23 // 4 = 5
# dim=3: indices.append(5 % 3 = 2), idx = 5 // 3 = 1
# dim=2: indices.append(1 % 2 = 1), idx = 1 // 2 = 0
# result: multi-dimensional index = (1, 2, 3)
Typical uses:
- recovering a multi-dimensional position from a flat offset
- determining an element's coordinates in the global tensor when rebuilding shards
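A quick self-check (illustrative, reusing the two helpers defined above): the pair should agree with numpy's row-major conversions and be mutual inverses.
import numpy as np
shape = (2, 3, 4)
for flat in range(2 * 3 * 4):
    idx = unravel_index(flat, shape)
    # matches numpy's row-major unravel, and raveling it back recovers the flat index
    assert idx == tuple(np.unravel_index(flat, shape))
    assert ravel_index(idx, shape) == flat == np.ravel_multi_index(idx, shape)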
2. Slice computation utilities
minimal_nd_slice(shape, flat_start, flat_end)
def minimal_nd_slice(shape, flat_start, flat_end):
start_idx = unravel_index(flat_start, shape)
end_idx = unravel_index(flat_end - 1, shape)
min_slices = []
for axis in range(len(shape)):
if axis == 0:
s = start_idx[axis]
e = end_idx[axis] + 1
else:
if start_idx[axis - 1] == end_idx[axis - 1]:
s = min(start_idx[axis], end_idx[axis])
e = max(start_idx[axis], end_idx[axis]) + 1
else:
s = 0
e = shape[axis]
min_slices.append((s, e))
return min_slices, start_idx, end_idx
Purpose: compute the smallest N-dimensional slice that covers a given flattened range.
Walkthrough:
# Example: shape = (4, 3), flat_start = 5, flat_end = 8
# convert the flattened range [5, 8) into a minimal slice
# computation:
# start_idx = unravel_index(5, (4, 3)) = (1, 2)
# end_idx = unravel_index(7, (4, 3)) = (2, 1)
# for axis=0:
#   s = 1, e = 2 + 1 = 3
# for axis=1:
#   start_idx[0] = 1, end_idx[0] = 2, not equal
#   so s = 0, e = 3
# result: min_slices = [(1, 3), (0, 3)]
# i.e. the slice [1:3, 0:3] is needed
Typical uses:
- converting a flattened index range into a minimal multi-dimensional slice
- reducing the amount of data transferred and improving efficiency
flat_range_in_min_slice(shape, min_slices, flat_start, flat_end)
def flat_range_in_min_slice(shape, min_slices, flat_start, flat_end):
min_starts = tuple(s[0] for s in min_slices)
min_flat_start = ravel_index(min_starts, shape)
return flat_start - min_flat_start, flat_end - min_flat_start
Purpose: compute the relative flattened range within the minimal slice.
Walkthrough:
# Example: shape = (4, 3), min_slices = [(1, 3), (0, 3)]
# flat_start = 5, flat_end = 8
# computation:
# min_starts = (1, 0)
# min_flat_start = ravel_index((1, 0), (4, 3)) = 3
# relative range = (5 - 3, 8 - 3) = (2, 5)
# i.e. the wanted elements' relative position inside the minimal slice
Typical uses:
- computing relative offsets within a slice
- used for precise data extraction and copying
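An illustrative check of the two helpers together on the worked example above (numpy used only for the toy tensor): the flattened range [5, 8) of a (4, 3) tensor is covered by the slice [1:3, 0:3], inside which the wanted elements occupy the relative flat range [2, 5).
import numpy as np
shape = (4, 3)
flat_start, flat_end = 5, 8
t = np.arange(12).reshape(shape)
min_slices, _, _ = minimal_nd_slice(shape, flat_start, flat_end)          # [(1, 3), (0, 3)]
rel_start, rel_end = flat_range_in_min_slice(shape, min_slices, flat_start, flat_end)  # (2, 5)
block = t[tuple(slice(s, e) for s, e in min_slices)]                      # the minimal slice
recovered = block.reshape(-1)[rel_start:rel_end]
assert recovered.tolist() == list(range(flat_start, flat_end))            # [5, 6, 7]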
3. State-dict checking utilities
is_sharded_state_dict(o)
def is_sharded_state_dict(o):
if not isinstance(o, dict):
return False
values = list(o.values())
has_sharded_weight = any(isinstance(v, ShardedWeight) for v in values)
if has_sharded_weight:
if not all(isinstance(v, ShardedWeight) for v in values):
raise TypeError(
"All values must be ShardedWeight if any value is ShardedWeight."
)
return True
else:
return False
Purpose: check whether a dict is a sharded state dict.
Walkthrough:
# Rules:
# 1. must be a dict
# 2. if any value is a ShardedWeight, then all values must be ShardedWeight
# 3. mixed types are not allowed
# Example:
valid_dict = {
"weight": ShardedWeight(...),
"bias": ShardedWeight(...)
}  # returns True
invalid_dict = {
"weight": ShardedWeight(...),
"bias": paddle.Tensor(...)
}  # raises TypeError
Typical uses:
- validating the checkpoint format
- ensuring the consistency of a state dict
4. Overlap-region computation
get_overlap_region(desc_offset, desc_shape, shard_offset, shard_shape)
def get_overlap_region(desc_offset, desc_shape, shard_offset, shard_shape):
ndim = len(desc_offset)
overlap_offset = []
overlap_shape = []
desc_starts = []
shard_starts = []
for i in range(ndim):
desc_lo = desc_offset[i]
desc_hi = desc_offset[i] + desc_shape[i]
shard_lo = shard_offset[i]
shard_hi = shard_offset[i] + shard_shape[i]
# overlap
lo = max(desc_lo, shard_lo)
hi = min(desc_hi, shard_hi)
if lo >= hi:
return False, None, None, None, None
overlap_offset.append(lo)
overlap_shape.append(hi - lo)
desc_starts.append(lo - desc_lo)
shard_starts.append(lo - shard_lo)
return True, overlap_offset, overlap_shape, desc_starts, shard_starts
Purpose: compute the overlap region between two shards.
Walkthrough:
# Example: 2-D tensors
# desc: offset=(0,0), shape=(4,4)
# shard: offset=(2,2), shape=(4,4)
# computation:
# dimension 0:
#   desc_lo=0, desc_hi=4, shard_lo=2, shard_hi=6
#   lo = max(0,2) = 2, hi = min(4,6) = 4
#   overlap_offset[0] = 2, overlap_shape[0] = 2
#   desc_starts[0] = 2-0 = 2, shard_starts[0] = 2-2 = 0
# dimension 1:
#   desc_lo=0, desc_hi=4, shard_lo=2, shard_hi=6
#   lo = max(0,2) = 2, hi = min(4,6) = 4
#   overlap_offset[1] = 2, overlap_shape[1] = 2
#   desc_starts[1] = 2-0 = 2, shard_starts[1] = 2-2 = 0
# result:
#   overlap region: offset=(2,2), shape=(2,2)
#   start within desc: (2,2)
#   start within shard: (0,0)
Typical uses:
- computing the data overlap between different sharding strategies
- providing the exact ranges needed for copying data
5. Shard data copy utility
assign_sharded_slice(src_desc, src_shard, dst_desc, dst_shard)
def assign_sharded_slice(src_desc, src_shard, dst_desc, dst_shard):
    # 1. compute the overlap between the source shard and its descriptor
src_has, _, overlap_shape, src_desc_starts, src_shard_starts = (
get_overlap_region(
src_desc.global_offset,
src_desc.local_shape,
src_shard.global_offset,
src_shard.local_shape,
)
)
    # 2. compute the overlap between the destination shard and its descriptor
dst_has, _, overlap_shape2, dst_desc_starts, dst_shard_starts = (
get_overlap_region(
dst_desc.global_offset,
dst_desc.local_shape,
dst_shard.global_offset,
dst_shard.local_shape,
)
)
    # 3. check that the two overlap regions are consistent
assert src_has or dst_has, "no overlap!"
assert overlap_shape == overlap_shape2, "overlap shape mismatch!"
    # 4. copy the data
axes = list(range(len(overlap_shape)))
src_tensor_slice = paddle.slice(
src_shard.local_tensor,
axes=axes,
starts=src_shard_starts,
ends=[s + o for s, o in zip(src_shard_starts, overlap_shape)],
)
dst_tensor_slice = paddle.slice(
dst_shard.local_tensor,
axes=axes,
starts=dst_shard_starts,
ends=[s + o for s, o in zip(dst_shard_starts, overlap_shape)],
)
paddle.assign(src_tensor_slice, dst_tensor_slice)
Purpose: copy the overlapping data between different shards.
Walkthrough:
# Full flow:
# 1. compute the overlap between the source shard and the descriptor
# 2. compute the overlap between the destination shard and the descriptor
# 3. check that the two overlap regions are consistent
# 4. extract the overlapping part from the source shard
# 5. copy it into the destination shard
# Example: converting from tp2 to tp4
# src_desc: describes the shard layout under tp2
# src_shard: the actual data under tp2
# dst_desc: describes the shard layout under tp4
# dst_shard: the target data under tp4
# The function:
# 1. finds the overlap between the tp2 and tp4 shards
# 2. copies the tp2 overlap data into the corresponding positions of the tp4 shard
Typical uses:
- redistributing data when converting between sharding strategies
- restoring data when loading a checkpoint
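A small numpy analogue of this copy (illustrative only, not the paddle code path): copy the region where a tp2 shard and a tp4 shard of the same (8, 4) tensor overlap, using get_overlap_region defined above.
import numpy as np
full = np.arange(32).reshape(8, 4)
# source: tp2 rank 0 holds rows 0..3; destination: tp4 rank 1 holds rows 2..3
src_offset, src_shape = (0, 0), (4, 4)
dst_offset, dst_shape = (2, 0), (2, 4)
src_local = full[0:4].copy()
dst_local = np.zeros(dst_shape, dtype=full.dtype)
has, _, ov_shape, src_starts, dst_starts = get_overlap_region(
    src_offset, src_shape, dst_offset, dst_shape
)
assert has
src_slc = tuple(slice(s, s + n) for s, n in zip(src_starts, ov_shape))
dst_slc = tuple(slice(s, s + n) for s, n in zip(dst_starts, ov_shape))
dst_local[dst_slc] = src_local[src_slc]
assert (dst_local == full[2:4]).all()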
6. Shard-info merging
merge_shard_info_list(list_of_dicts)
def merge_shard_info_list(list_of_dicts):
merged = defaultdict(list)
for info in list_of_dicts:
for k, v in info.items():
merged[k].extend(v)
return dict(merged)
Purpose: merge multiple shard-info dictionaries.
Walkthrough:
# Example:
list_of_dicts = [
{"param1": [info1, info2]},
{"param1": [info3], "param2": [info4]},
{"param2": [info5, info6]}
]
# merged result:
merged = {
"param1": [info1, info2, info3],
"param2": [info4, info5, info6]
}
Typical uses:
- collecting the shard info from all ranks
- building the global shard view
7. Descriptor construction
build_shard_desc(val)
def build_shard_desc(val):
return ShardedWeightDesc(
key=val.key,
local_shape=tuple(val.local_shape),
global_shape=tuple(val.global_shape),
global_offset=tuple(val.global_offset),
)
Purpose: build a ShardedWeightDesc from a ShardedWeight.
Walkthrough:
# Conversion:
# input: a ShardedWeight object (carries the actual data)
# output: a ShardedWeightDesc object (metadata only)
# Example:
sharded_weight = ShardedWeight(
key="linear.weight",
    local_tensor=paddle.Tensor(...),  # the actual data
local_shape=(1024, 512),
global_shape=(1024, 2048),
global_offset=(0, 0)
)
# converted to:
shard_desc = ShardedWeightDesc(
key="linear.weight",
local_shape=(1024, 512),
global_shape=(1024, 2048),
global_offset=(0, 0)
)
Typical uses:
- extracting the metadata of a sharded weight
- used when transmitting and storing shard information
1.3 Key components of sharded_tensor
class ShardedTensor:
"""
Represents a local shard of a distributed tensor parameter.
Args:
key (str): The name of the parameter.
local_tensor (Tensor): The local shard of the parameter.
local_shape (Tuple[int, ...]): The shape of the local shard.
global_shape (Tuple[int, ...]): The global logical shape of the parameter.
global_offset (Tuple[int, ...]): The offset of the local shard in the global parameter.
is_flattened (bool, optional): Whether the parameter has been flattened (used in sharding_v2 scenarios). Default is False.
flattened_range (slice, optional): If the parameter is flattened, this indicates the index range of the actual local shard within the local_tensor.
"""
def __init__(
self,
key: str,
local_tensor: Tensor,
local_shape: tuple[int, ...],
global_shape: tuple[int, ...],
global_offset: tuple[int, ...],
is_flattened: bool = False,
flattened_range: slice | None = None,
) -> None:
self.key = key
self.local_tensor = local_tensor
self.local_shape = local_shape
self.global_shape = global_shape
self.global_offset = global_offset
self.is_flattened = is_flattened
self.flattened_range = flattened_range
def __str__(self) -> str:
"""Returns a formatted string representation of the sharded tensor."""
return (
f"ShardedTensor(\n"
f" key={self.key},\n"
f" local_tensor={type(self.local_tensor).__name__}(shape={self.local_tensor.shape}),\n"
f" local_shape={self.local_shape},\n"
f" global_shape={self.global_shape},\n"
f" global_offset={self.global_offset},\n"
f" flattened_range={self.flattened_range}\n"
f")"
)
def shard_weight(
key: str,
weight: Tensor,
axis: int,
group: Group,
) -> ShardedTensor:
"""Creates a ShardedTensor by splitting the input tensor along a specified axis.
Args:
key: Unique identifier for the tensor.
weight: The input tensor to be sharded.
axis: The axis along which to shard the tensor.
group: The process group used for distributed communication.
Returns:
A ShardedTensor representing the local portion of the global tensor.
"""
if axis < 0 or axis >= len(weight.shape):
raise ValueError(
f"Shard axis {axis} is invalid for tensor with shape {weight.shape}"
)
# Get hybrid communication group and rank information
hcg = fleet.get_hybrid_communicate_group()
current_rank = group.rank
world_size = group.nranks
# Calculate shapes and offsets
local_shape = weight.shape
global_shape = deepcopy(local_shape)
global_shape[axis] = local_shape[axis] * world_size
global_shape = tuple(global_shape)
local_shape = tuple(local_shape)
global_offset = [0] * len(global_shape)
if world_size > 1:
global_offset[axis] = current_rank * local_shape[axis]
global_offset = tuple(global_offset)
return ShardedTensor(
key=key,
local_tensor=weight,
local_shape=local_shape,
global_shape=global_shape,
global_offset=global_offset,
)
def build_sharded_state_dict(
state_dict: dict[str, Tensor],
shard_rules: dict[str, int] | None = None,
prefix: str = "",
) -> dict[str, ShardedTensor]:
"""Converts a regular state dict to a sharded state dict based on sharding rules.
Args:
state_dict: The original state dictionary containing tensors
shard_rules: Dictionary mapping tensor names to their sharding axes.
If None, treated as empty dict (no tensor parallelism).
prefix: Optional prefix to prepend to all tensor keys
Returns:
Dictionary with the same keys as input but values converted to ShardedTensor
or regular Tensor based on sharding rules.
Note:
Tensors not in shard_rules will be wrapped as non-sharded ShardedTensors.
"""
shard_rules = shard_rules or {}
sharded_state_dict = {}
for key, tensor in state_dict.items():
full_key = f"{prefix}{key}" if prefix else key
if key in shard_rules:
# Apply tensor parallelism sharding
sharded_state_dict[full_key] = (
make_tp_sharded_tensor_for_checkpoint(
key=full_key,
tensor=tensor,
tensor_parallel_axis=shard_rules[key],
)
)
else:
# Create regular sharded tensor (non-tensor-parallel)
sharded_state_dict[full_key] = make_replicated_sharded_tensor(
key=full_key,
tensor=tensor,
)
return sharded_state_dict
This part is mainly the ShardedTensor class plus the two interfaces build_sharded_state_dict and shard_weight. ShardedTensor is the basic unit of the subsequent sharded state dict, i.e. a dict of the form {key: ShardedTensor}; in the old version the values were plain Tensors, whereas a ShardedTensor also carries the split information, mainly local_shape, global_shape and global_offset, from which the global tensor can be rebuilt out of the local_tensor shards and then resharded.
build_sharded_state_dict starts from an ordinary state_dict: tensors that need distributed handling (i.e. are sharded) are marked with their split and converted into ShardedTensors. make_tp_sharded_tensor_for_checkpoint handles the mp (tensor-parallel) case and internally calls the shard_weight interface, returning a ShardedTensor. Tensors that are not split are still wrapped with make_replicated_sharded_tensor so that everything becomes a uniform ShardedTensor; this path does not call shard_weight and directly returns a ShardedTensor with local_shape == global_shape, because every rank holds an identical copy of that data.
For shard_weight, the tensor passed in differs between ranks along the split dimension (within the shard group), so summing that dimension's size across the ranks yields the global_shape, and a tensor carrying the distributed information can be constructed from it.
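For concreteness, a small illustrative construction (not taken from the source) of what shard_weight would return on mp rank 1 of a 2-way group for a column-parallel weight whose global shape is (1024, 2048), split along axis 1:
import paddle
local = paddle.zeros([1024, 1024])           # this rank's half of the columns
st = ShardedTensor(
    key="linear.weight",
    local_tensor=local,
    local_shape=(1024, 1024),
    global_shape=(1024, 2048),               # local size * world_size along axis 1
    global_offset=(0, 1024),                 # rank 1 starts at column 1024
)
print(st)                                    # shows the shard metadata carried alongside the data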
1.4 Key components of load_state_dict
1.4.1 get_rank_to_files (unchanged from the original implementation)
def get_rank_to_files(
metadata_list,
local_data_files,
state_dict,
process_group,
use_dist,
mw_name_compatibility=True,
):
"""
Get the mapping of rank to its accessible files.
"""
# The necessary files to be read
tensor_key_list = []
necessary_files = []
mw_name_compatibility_mapping = {}
for metadata in metadata_list:
for local_tensor_index, file_name in metadata.storage_metadata.items():
assert (
local_tensor_index not in tensor_key_list
), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata."
tensor_key_list.append(local_tensor_index.tensor_key)
if local_tensor_index.tensor_key in state_dict:
necessary_files.append(file_name)
all_necessary_files = []
if use_dist:
paddle.distributed.all_gather_object(
all_necessary_files, necessary_files, process_group
)
else:
all_necessary_files.append(necessary_files)
global_necessary_files = [
file for files in all_necessary_files for file in files
]
global_necessary_files_set = set(global_necessary_files)
if len(global_necessary_files_set) <= 0:
logger.warning(
"No necessary data files found in the checkpoint directory. Please check the metadata."
)
missing_keys = set(state_dict.keys())
return {}, missing_keys, mw_name_compatibility_mapping
# allgather all accessible files
global_data_files = []
if use_dist:
paddle.distributed.all_gather_object(
global_data_files, local_data_files, process_group
)
else:
global_data_files.append(local_data_files)
tmp = []
for files in global_data_files:
tmp += files
global_data_files_set = set(tmp)
logger.debug(
f"necessary_data_files_set:{global_necessary_files_set}, global_data_files_set:{global_data_files_set}"
)
# check necessary files in global_data_files
assert (
global_data_files_set & global_necessary_files_set
== global_necessary_files_set
), f"The checkpoint files are not complete. Please check the checkpoint directory. global_data_files_set:{global_data_files_set}, necessary_data_files_set:{global_necessary_files_set}"
missing_keys = set(state_dict.keys()) - set(tensor_key_list)
if len(missing_keys) > 0:
if mw_name_compatibility:
mw_name_compatibility_mapping = _modify_mw_name_for_compatibility(
state_dict, missing_keys, tensor_key_list
)
if len(missing_keys) > 0:
logger.warning(
f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
)
else:
logger.warning(
f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
)
rank_to_files = {}
for rank, need_files in enumerate(all_necessary_files):
seen = set()
unique_need_files = [
f for f in need_files if not (f in seen or seen.add(f))
]
rank_to_files[rank] = unique_need_files
logger.debug(f"mapping rank_to_files:{rank_to_files}")
return rank_to_files, missing_keys, mw_name_compatibility_mapping
Based on the saved storage_metadata, the current rank's state_dict is walked; whether local_tensor_index.tensor_key appears in state_dict decides whether the file behind that local_tensor_index is needed, and if so it is appended to necessary_files. all_necessary_files then holds every rank's necessary_files, for example:
all_necessary_files = [
["0_0.distcp", "1_0.distcp"], # rank 0 需要的文件
["2_0.distcp", "3_0.distcp"], # rank 1 需要的文件
["4_0.distcp", "5_0.distcp"], # rank 2 需要的文件
["6_0.distcp", "7_0.distcp"], # rank 3 需要的文件
]
In the resulting rank_to_files mapping, the key is the rank id and the value is the list of files that rank needs; the seen set performs order-preserving de-duplication.
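A toy illustration of that selection (plain Python, with LocalTensorIndex replaced by simple tuples; the file names are only examples): a shard's file is marked necessary exactly when its tensor_key appears in the local state_dict, and the seen set then de-duplicates while preserving order.
storage_metadata = {
    ("linear.weight", (0, 0)): "0_0.distcp",
    ("linear.weight", (256, 0)): "1_0.distcp",
    ("linear.bias", (0,)): "0_0.distcp",
}
state_dict_keys = {"linear.weight"}          # parameters this rank needs to load
necessary_files = [
    file_name
    for (tensor_key, _), file_name in storage_metadata.items()
    if tensor_key in state_dict_keys
]
seen = set()
unique_files = [f for f in necessary_files if not (f in seen or seen.add(f))]
print(unique_files)                          # ['0_0.distcp', '1_0.distcp']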
1.5 PaddleNLP adaptation
Conclusion: only LlamaLMHead's shard-saving rule needs a special convention in this file; the other layers either already implement sharded_state_dict in their own code or can rely on the default recursive collection, whereas the LM Head must explicitly tell the checkpoint system which axis its weight is split along.
Why do only LlamaLMHead and PipelinePretrainedModel need adaptation?
1. LlamaLMHead
- The LM Head's shard axis is not fixed: LlamaLMHead supports transpose_y and vocab parallelism, which change the logical shape of the weight and therefore the dimension that gets split:
  - when transpose_y=True and tie_word_embeddings is set, the weight has shape [vocab_size, hidden_size] and should be split along axis=0;
  - otherwise it is usually [hidden_size, vocab_size] and should be split along axis=1;
  - the code computes the axis explicitly, axis = 0 if self.transpose_y else 1, and then:
    # L2000-L2006
    state_dict = self.state_dict(structured_name_prefix="")
    return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)
    This step makes sure the unified checkpoint records how the vocab dimension is split, so the weight can be reconstructed across parallel strategies.
- Other modules either already have a sharding implementation or can use the default mechanism:
  - the ColumnParallelLinear and RowParallelLinear used in attention/MLP (and their sequence-parallel variants) already handle saving their sharded parameters in their own implementations; other weights such as LlamaRMSNorm.weight involve no ambiguity about a parallel split axis, so the default recursion is enough;
  - the word embedding VocabParallelEmbedding likewise has its own distributed attributes and export path in the parallel library.
- The LM Head also involves weight sharing and parallel output:
  - with tie_word_embeddings it shares its weight with the embedding, and is_distributed / split_axis are set for tensor parallelism;
  - this makes the LM Head the one layer that must declare its shard axis explicitly, so the unified checkpoint does not break when reconstructing/resharding (e.g. switching from TP2 to TP4).
2. PipelinePretrainedModel
2. Add sharded_state_dict handling to the relevant distributed APIs
2.1 VocabParallelEmbedding
2.1.1 Input it receives
Text input
User input: "Hello world, how are you?"
Tokenization
Tokenized: ["Hello", "world", ",", "how", "are", "you", "?"]
Vocabulary mapping
Vocabulary: {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3,
"Hello": 4, "world": 5, ",": 6, "how": 7, "are": 8, "you": 9, "?": 10, ...}
Mapped IDs: [4, 5, 6, 7, 8, 9, 10]
The model therefore receives a sequence of token IDs
Model input: x = [4, 5, 6, 7, 8, 9, 10] (token ID sequence)
So the input x that VocabParallelEmbedding receives has shape [batch_size, seq_len], i.e. a batch of token ID sequences.
2.1.2 How the input is processed
Assume vocab_size = 1,000,000 and embedding_dim = 1024 (matching the 4-GPU split example below): there are 1,000,000 tokens, and each token is represented by a feature vector of length 1024.
Each token ID corresponds to one row of the weight matrix W:
word_id=0 -> W[0, :] = [0.1, 0.2, 0.3, ..., 0.1024]
word_id=1 -> W[1, :] = [0.5, 0.1, 0.8, ..., 0.2048]
word_id=2 -> W[2, :] = [0.3, 0.7, 0.2, ..., 0.3072]
...
word_id=999999 -> W[999999, :] = [0.9, 0.4, 0.6, ..., 0.1024]
The input is:
# input: x = [batch_size, seq_len] (token IDs)
# e.g.: x = [[100, 250000, 500000, 750000],
#            [150, 250100, 500100, 750100]]
After splitting the vocab across 4 GPUs:
GPU0: W[0:250000, :] (250000 rows, 1024 columns)
GPU1: W[250000:500000, :] (250000 rows, 1024 columns)
GPU2: W[500000:750000, :] (250000 rows, 1024 columns)
GPU3: W[750000:1000000, :] (250000 rows, 1024 columns)
The lookup after parallelization works as follows for the input token IDs on each GPU:
GPU0 (owns tokens 0-249999):
- input ID=100: looks up W[100, :] = [0.1, 0.2, ..., 0.1024]
- input ID=150: looks up W[150, :] = [0.3, 0.4, ..., 0.1024]
- input ID=250000: out of range, returns a zero vector (or other special handling)
- input ID=500000: out of range, returns a zero vector (or other special handling)
GPU1 (owns tokens 250000-499999):
- input ID=100: out of range, returns a zero vector
- input ID=250000: looks up W[250000, :] = [0.5, 0.6, ..., 0.1024]
- input ID=250100: looks up W[250100, :] = [0.7, 0.8, ..., 0.1024]
- input ID=500000: out of range, returns a zero vector
GPU2 (owns tokens 500000-749999):
- input ID=500000: looks up W[500000, :] = [0.9, 0.1, ..., 0.1024]
- input ID=500100: looks up W[500100, :] = [0.2, 0.3, ..., 0.1024]
GPU3 (owns tokens 750000-999999):
- input ID=750000: looks up W[750000, :] = [0.4, 0.5, ..., 0.1024]
- input ID=750100: looks up W[750100, :] = [0.6, 0.7, ..., 0.1024]
Finally the per-GPU results are merged with an allreduce, giving the final output of shape [batch_size, seq_len, embedding_dim].
The layer's weight is randomly initialized at first, so although every token is represented by a vector, the vectors are meaningless at that point; after training, the embeddings of related tokens gradually become similar, so that at inference time each token maps to a meaningful embedding.
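A hedged sketch (not the actual Paddle implementation) of how VocabParallelEmbedding's sharded_state_dict could be expressed with the section 1.3 helpers; the function name here is illustrative. The key point is that the table is split along axis 0, the vocab dimension, so that is the axis recorded in the shard rule.
def vocab_parallel_embedding_sharded_state_dict(layer, prefix=""):
    # "weight" is the local slice of the table, shape (vocab_size / mp_degree, embedding_dim);
    # recording axis 0 lets the checkpoint rebuild the full (vocab_size, embedding_dim) table.
    return build_sharded_state_dict(layer.state_dict(), {"weight": 0}, prefix)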
2.2 How ColumnParallelLayer and RowParallelLayer are used together
2.2.1 ColumnParallelLayer

2.2.2 RowParallelLayer

As shown, RowParallelLayer needs its input split into two column blocks that are computed on the two devices, and each device ends up with a partial result. If the previous layer is ColumnParallel, its output is already distributed across the two devices (i.e. split by column), which is exactly the input layout RowParallelLayer needs; no communication is required in between, the computation simply continues and only a final allreduce is performed.
2.2.3 How w and bias are split in ColumnParallelLayer and RowParallelLayer

Note that for y = x * W^T + b, the matmul result has shape [batch_size, output_size], where each row is one sample; the bias, a 1-D vector, is added to every row. So when W is split by column, the bias has to be split along its (only) axis accordingly to keep the computation correct.
With a bias present, RowParallelLayer and ColumnParallelLayer behave as follows:
RowParallelLayer:

RowParallelLayer splits only w, not the bias
ColumnParallelLayer:

ColumnParallelLayer splits w along axis=1 and the bias along axis=0
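A hedged sketch mirroring the split rules just stated (the function names are illustrative, not the framework's API, and the weight is assumed to be stored as [in_features, out_features] as in the column-parallel case above):
def column_parallel_linear_sharded_state_dict(layer, prefix=""):
    # ColumnParallelLinear: weight split along axis 1 (output features),
    # bias split along axis 0 so it matches the local output columns.
    return build_sharded_state_dict(layer.state_dict(), {"weight": 1, "bias": 0}, prefix)

def row_parallel_linear_sharded_state_dict(layer, prefix=""):
    # RowParallelLinear: weight split along axis 0 (input features);
    # the bias is not split, it is added once after the allreduce, so it stays replicated.
    return build_sharded_state_dict(layer.state_dict(), {"weight": 0}, prefix)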
2.3 DygraphShardingOptimizerV2
Core goal
sharded_state_dict solves the problem of converting state between different parallel strategies, as well as the problem that, in the V2 case, the optimizer states are flattened:
- e.g. switching from tp2 to tp4 requires re-partitioning the parameters
- data integrity must be preserved: parameters and optimizer states must be converted correctly
- it supports resuming training seamlessly across different parallel configurations
Implementation summary
1. Shard-info collection phase
# Step 1: collect the information of the current sharding layout
for comm_group, buffers in comm_group_buffers.items():
    for buffer in buffers:
        for param_name, grad_view in buffer._sharding_param_grad_view.items():
            # record each parameter's shard range on this rank
            param_slice_info[param_name] = (
                grad_view._param_begin,  # shard start position
                grad_view._param_end,  # shard end position
            )
            # record the parameter's full shape information
            param_shape_info[param_name] = (
                grad_view._param.shape,  # original shape
                grad_view._param.numel().item(),  # total number of elements
                grad_view._index,  # shard index
                grad_view._padded_size,  # padded size
            )
Purpose: record how each parameter is distributed across ranks under the current sharding layout.
2. Global info synchronization phase
# Step 2: gather the shard info from every rank
for comm_group, buffers in comm_group_buffers.items():
    # tag the info with this rank's sharding rank
    param_slice_info["sharding_rank"] = comm_group.rank
    # all_gather the info from all ranks
    gathered_info = []
    paddle.distributed.all_gather_object(
        gathered_info, param_slice_info, group=comm_group
    )
    all_rank_slice_info.extend(gathered_info)
Purpose: give every rank the complete picture of the shard distribution, preparing for the later reconstruction.
3. Identifying partially sharded tensors
# Step 3: identify which tensors are partially sharded
for param_key, tensor in optim_state_dict.items():
    base_name, _ = _generate_base_static_name(param_key)
    if int(tensor.numel()) > 1:  # non-scalar tensor
        begin, end = merged_slice_info[base_name]
        shape_info = merged_shape_info[base_name]
        # partially sharded: shard size < original size
        if shape_info and end > begin and end - begin < shape_info[1]:
            partial_tensor_names.append(base_name)
Purpose: distinguish fully sharded from partially sharded tensors, since they require different handling.
4. Offset-mapping computation phase
# Step 4: compute each rank's offset within the full tensor
for tensor_name in partial_tensor_names:
    offset_mapping[tensor_name] = [0] * world_size
    # record each rank's shard size
    for info in all_rank_slice_info:
        if tensor_name in info:
            begin, end = info[tensor_name]
            if end > begin:
                offset_mapping[tensor_name][info["sharding_rank"]] = end - begin
    # convert the sizes into cumulative (exclusive prefix) offsets
    running_total = 0
    for rank in range(world_size):
        current_size = offset_mapping[tensor_name][rank]
        offset_mapping[tensor_name][rank] = running_total
        running_total += current_size
Purpose: compute each rank's starting position within the full tensor, used to rebuild the complete tensor.
5. State-dict construction phase
# Step 5: build the sharded state dict
for param_key, tensor in optim_state_dict.items():
    base_name, optim_state_type = _generate_base_static_name(param_key)
    struct_name = static_to_struct[base_name]
    sharded_param = model_sharded_state_dict[struct_name]
    unified_name = f"{struct_name}.{optim_state_type}"
    # three kinds of tensors are handled
    if int(tensor.numel()) == 1:
        # scalar state: saved directly
        sharded_weight = ShardedWeight(...)
    elif base_name in partial_tensor_names:
        # partially sharded tensor: record its position within the full tensor
        flattened_offset = offset_mapping[base_name][sharding_rank]
        sharded_weight = ShardedWeight(
            flattened_range=slice(flattened_offset, flattened_offset + int(tensor.numel()))
        )
    else:
        # fully sharded tensor: this rank owns the complete shard
        sharded_weight = ShardedWeight(
            flattened_range=slice(0, int(tensor.numel()))
        )
Purpose: create a ShardedWeight object carrying the full shard information for every optimizer state.
Key design ideas
1. Layered information recording
# Three levels of information are recorded:
# 1. parameter level: param_slice_info - shard ranges
# 2. shape level: param_shape_info - full shapes
# 3. global level: offset_mapping - global offsets
2. Case-by-case handling
# Three handling strategies:
# 1. scalar states: saved directly, no shard info needed
# 2. partially sharded tensors: record their position within the full tensor
# 3. fully sharded tensors: the current rank owns the complete shard
3. Building the global view
# Every rank gathers the global information:
# 1. the shard ranges of all ranks
# 2. the complete parameter shapes
# 3. the global offset mapping
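To make the flattened_range bookkeeping concrete, here is a small self-contained illustration (pure numpy, not the optimizer code): per-rank shard sizes become exclusive prefix offsets, each rank's slice is described by a flattened_range, and together the ranges rebuild the full flat tensor.
import numpy as np
full = np.arange(300, dtype=np.float32)                  # the "complete" flat moment tensor
sizes = [100, 0, 120, 80]                                # per-rank shard sizes (rank 1 owns nothing)
offsets = np.concatenate(([0], np.cumsum(sizes)[:-1]))   # exclusive prefix sum, as in step 4
rebuilt = np.empty_like(full)
for rank, (off, size) in enumerate(zip(offsets, sizes)):
    flattened_range = slice(int(off), int(off) + size)   # what the ShardedWeight records
    rebuilt[flattened_range] = full[flattened_range]     # this rank's contribution
assert (rebuilt == full).all()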
2.4 SP (sequence parallelism)
Similar to ColumnParallel/RowParallel, except the split dimension is seq_len, and it is used together with tp (mp).
2.5 On structured_name_prefix being empty in the sharded_state_dict method

In fact this code recursively calls each sub_layer's sharded_state_dict method, so that the current layer's name is passed down to the sub_layer as a prefix.
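A minimal sketch of that recursion (an assumed structure, not the verbatim framework code; collecting plain, non-sharded parameters is omitted for brevity): the accumulated prefix is passed down so that every key lines up with the full model's structured names.
def collect_sharded_state_dict(layer, structured_name_prefix=""):
    sharded = {}
    if hasattr(layer, "sharded_state_dict"):
        # layers that implement their own rule (e.g. the LM Head) receive the prefix directly
        sharded.update(layer.sharded_state_dict(structured_name_prefix))
    else:
        for name, sub_layer in layer.named_children():
            # recurse, extending the prefix with this sublayer's name
            sharded.update(
                collect_sharded_state_dict(sub_layer, f"{structured_name_prefix}{name}.")
            )
    return sharded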
3. Issues encountered while testing ERNIE
1. self.args.offload_optim
Offloading the optimizer causes the moments in the converted, saved optimizer.pdopt to go missing.
During model conversion, e.g. dp2->dp4, after the optimizer has been offloaded, state_dict() only contains master_params and the scheduler data; the moments have been offloaded to CPU, so saving them fails.

2. Converting sharding4 to pure dp2: the file names do not match, so the checkpoint cannot be loaded

Sharding saves model files named model_state_shardxx, while pure DP loads a model file named model_state.pdparams. So a checkpoint saved under sharding4 cannot be loaded correctly after switching to pure dp2 training.
3. Checkpoint file path issue

- After the checkpoint file paths were changed, paddleformers has to update the load_state_dict and save_state_dict it imports accordingly.
4. MoELayer cannot find the config attribute

Cause:

This code uses some fields of config before config is defined; I simply commented it out.
5. Pure sharding fails: the data type used during gradient accumulation is wrong


The commented-out lines are the original code. Testing the latest Paddle, however, shows that add_ now handles adding fp16 and float32 values of different precisions, so this was presumably a bug introduced by a PR merged at the time and has since been fixed.
6. tp2(ep2)->pp4: after loading the source checkpoint, the optimizer state of one parameter is reported missing
lm_head and the embedding share a single weight, so the optimizer also holds only one copy of the corresponding optimizer state.
This causes an error when loading the checkpoint:

Main cause

Loading a checkpoint requires initializing both the model and the optimizer. With the flex_ckpt framework, init_opt creates an optimizer state for every key in the model. Here embed_layer and lm_head_layer share the same parameter, so only one optimizer state exists, and the embed_tokens state created at initialization cannot be found in the checkpoint, which raises the error.
Tracking the problem down:
Printing the model and the corresponding values:


Both keys can be seen to point to the same tensor; the concrete implementation is in tie_weights:



Here lm_head's weight is assigned directly to the tensor owned by the embedding.
Why does sharing one parameter also mean sharing one optimizer state?




The above is the flow involved in creating the optimizer. The parameters held by the optimizer are de-duplicated by params, i.e. directly by Tensor rather than by key, so parameters that share a tensor keep only one entry in the optimizer's parameter list, namely the one that appears first.



When the accumulators are created, only {key: embedding_0.w_0, shared_tensor} remains, so only the embedding's optimizer state is created and none is created for lm_head; the printout looks as follows:
Why does the error say the embedding's optimizer state is missing, rather than lm_head's?

Here, because both entries have the same v.local_tensor.name, the former is overwritten by the latter.
Problem summary:
Summary of the tp2(ep2)->pp4 problem:
Symptom: when pp4 loads the checkpoint saved under tp2(ep2), the AOAEngine is constructed and shape_propagation is called; parameters not rewritten by any AOA rule get a completion mapping, and the code checks whether each completed key ever appeared in the source strategy (tp2ep2). If not, it raises an error, here: ernie.embed_tokens.weight.moment1_0 cannot be found (in fact none of the optimizer states related to embed_tokens.weight can be found).
Cause: init_optimizer() creates optimizer states from the parameters of the currently loaded model, one state per parameter. In the non-PP ERNIE 4.5 network, tie_weights makes the lm_head and embedding parameters share a single tensor; when the optimizer states were created under tp2(ep2), parameters with the same param.name got only one optimizer state, named after the key that appeared first. So it is indeed correct that lm_head and the embedding have only one set of states, and the init_opt logic needs to be improved here.
Given that logic, however, the embedding-related optimizer states should be the ones that are found and the lm_head-related ones missing, not the other way around. Digging further, AdamW's sharded_state_dict does not check for parameters that share the same tensor when it builds static_to_struct_mapping, so for layers sharing one weight the later layer's key overwrites the earlier layer's; lm_head comes later, so it overwrites embed, and the optimizer appears to contain only lm_head's state. The sharded_state_dict logic inside dygraph_sharding_optimizer and AdamW therefore needs fixing as well.
For ERNIE's PP network, tie_weights is not supported, so lm_head and the embedding each keep an independent weight, whereas the non-PP network always calls tie_weights; as things stand, the tp2(ep2)->pp4 conversion therefore cannot be done.
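A tiny, purely illustrative reproduction of the overwrite described above (the names mimic the printed keys; this is not the AdamW code): when the mapping is keyed by the shared tensor's static name, the later structured key silently replaces the earlier one.
shared_static_name = "embedding_0.w_0"                  # v.local_tensor.name of the tied weight
model_sharded_state_dict = {
    "ernie.embed_tokens.weight": shared_static_name,
    "lm_head.weight": shared_static_name,               # same tensor after tie_weights
}
static_to_struct = {}
for struct_name, static_name in model_sharded_state_dict.items():
    static_to_struct[static_name] = struct_name         # no check for an existing entry
print(static_to_struct)   # {'embedding_0.w_0': 'lm_head.weight'} -- embed_tokens has been lost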
7. An error occurs when bias is set to False

The main cause is that this code scales the bias directly, but when bias is None it cannot be scaled, which raises the error.

The following change fixes it:

8. Testing tp2(ep2)->tp4: one parameter's MD5 does not match
Final conclusion: the FLAGS_shard_bypass_dygraph_optimizer flag only controls parameter updates performed by the optimizer itself; it cannot prevent direct parameter updates made inside callbacks. During on_optimizer_end, OrthogonalCallback computes the orthogonal loss and directly updates the ernie.layers.1.mlp.gate.weight parameter; this direct modification bypasses the optimizer's control mechanism, so the parameter gets updated unexpectedly during the save/load conversion, which makes the MD5 check fail. This shows that the FlexCheckpoint framework's own logic is fine.

9. Testing tp2(ep2)->pp4: several parameters' MD5 do not match

This turned out to be caused by an option enabled under MoE mode, which gives the MoE model an extra moe_statics.e_score_correction_bias parameter.
10. tp2(ep2) to vpp4 is currently broken; neither num_hidden_layer=8 nor 9 works

Mainly: when the multi_token_pred_depth option is enabled, an extra MTP layer is added to the hidden layers, which breaks the assertion below; ERNIE 4.5's VPP currently does not support adding this MTP layer:
assert sum(weights) % actual_num_parts == 0
But after turning multi_token_pred_depth off, a p2p communication error appears instead, which looks like a bug in ERNIE 4.5's own VPP path:

11. Converting tp2(ep2) to tp2(ep2)+sd2 and dp2 to tp2(ep2)+sd2: some optimizer states disappear

Again caused by optimizer offload being enabled.
12. DP2 to DP4

Running under ERNIE fails because the optimizer is not wrapped and therefore has no _inner_opt, whereas it does not fail under Llama because there the optimizer is wrapped in
paddle.distributed.fleet.utils.mix_precision_utils.MixPrecisionOptimizer
so a check needs to be added:
inner_opt = getattr(optimizer, "_inner_opt", None)
if DygraphShardingOptimizer is not None and isinstance(inner_opt, DygraphShardingOptimizer):
    local_params = optimizer._rank2params[optimizer._sharding_rank]
    optimizer._create_accumulators(
        paddle.base.framework.default_main_program().global_block(), local_params
    )
    return
elif DygraphShardingOptimizerV2 is not None and isinstance(inner_opt, DygraphShardingOptimizerV2):
    ...
Note: during ERNIE 4.5 training, the self._wrap_model called inside the train function is the method defined in /home/ERNIE/examples/pre-training/ernie/src/trainers/pretraining_trainer.py.
13. DP2 to Sharding4_V1 (with dp_group and sharding_group enabled; verifying a pure strategy requires enabling dummy)

There are 64 experts in total. With 2-card sharding only 32 experts have optimizer states, and with 4-card sharding only 16 do. The reason is that after one training step the expert parameters on every rank get synchronized; I have not pinpointed where, but after 5 training steps the MD5s of the 16 expert parameters on the 4 cards are completely identical, even though at initialization the expert parameters on the 4 cards differ.

Every card performs the same operation, deepcopy(fn), and for every rank this fn is the same layer.
Under DP, by contrast, each rank ...

14. DP2 to Sharding4_V2
The error is as follows:

15. DP2 to Sharding2_EP2
The error below occurs, mainly when converting sharding2_ep2 back to dp2, although the resumed loss still matches to 1e-5:

Again caused by this option being enabled.
16. DP2, ShardingV1, and V2 to TP2(EP2)_PP2
The following error occurs:

The main cause is that SequenceParallelLayer in ERNIE 4.5 has not been adapted for sharded_state_dict.
17. TP2(EP2) with num_hidden_layer=3: MD5 does not match
